import os
import random
import sys

import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
import pandas as pd

from torch_geometric.utils import to_undirected, add_remaining_self_loops

def load_data(dataset_name, dataset_folder='/data/shared/zhexu/'):
    """Load a node-classification graph and normalise its edge index.

    Supported values of ``dataset_name`` are ``'ogbn-arxiv'``, ``'cora'``,
    ``'pubmed'`` and ``'ogbn-products'``. Whatever the source, the returned
    graph's ``edge_index`` is made undirected and every node receives a
    self-loop if it does not already have one.

    Args:
        dataset_name: Name of the dataset to load (see the list above).
        dataset_folder: Root directory holding (or receiving) the dataset
            files. A trailing path separator is optional.

    Returns:
        A ``(num_classes, data)`` tuple: the number of target classes and the
        graph data object with ``edge_index`` normalised as described above.
        For ``'ogbn-arxiv'`` the object additionally carries boolean
        ``train_mask`` / ``val_mask`` / ``test_mask`` attributes built from
        the dataset's index split.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == 'ogbn-arxiv':
        dataset = PygNodePropPredDataset(name=dataset_name, root=dataset_folder)
        data = dataset[0]
        idx_splits = dataset.get_idx_split()
        # The split comes as node-index collections; turn each one into a
        # boolean per-node mask stored on the data object.
        for mask_name, split_key in (('train_mask', 'train'),
                                     ('val_mask', 'valid'),
                                     ('test_mask', 'test')):
            mask = torch.zeros(data.num_nodes, dtype=torch.bool)
            mask[idx_splits[split_key]] = True
            setattr(data, mask_name, mask)
        num_classes = 40

    elif dataset_name in ('cora', 'pubmed'):
        dataset = Planetoid(dataset_folder, dataset_name,
                            transform=T.NormalizeFeatures())
        data = dataset[0]
        num_classes = 3 if dataset_name == 'pubmed' else 7

    elif dataset_name == 'ogbn-products':
        # os.path.join keeps the path correct whether or not dataset_folder
        # ends with a separator.
        subset_path = os.path.join(dataset_folder, 'ogbn_products',
                                   'ogbn-products_subset.pt')
        # NOTE: torch.load unpickles arbitrary Python objects; only point
        # this at trusted files.
        data = torch.load(subset_path)

        # The stored subset exposes its adjacency as `adj_t` (presumably a
        # torch_sparse SparseTensor -- confirm against the file producer).
        # Symmetrise it and convert to a dense (2, num_edges) index tensor.
        row, col, _ = data.adj_t.to_symmetric().coo()
        data.edge_index = torch.stack([row, col], dim=0)
        num_classes = 47

    else:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of "
            "'ogbn-arxiv', 'cora', 'pubmed', 'ogbn-products'."
        )

    # num_nodes is passed by keyword: in current PyG the second positional
    # parameter of to_undirected is edge_attr, not num_nodes.
    data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
    data.edge_index, _ = add_remaining_self_loops(data.edge_index,
                                                  num_nodes=data.num_nodes)

    return num_classes, data